{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Experiment tracking is essential in machine learning because it enables data scientists and researchers to effectively manage and reproduce their experiments. By tracking various aspects of an experiment, such as hyperparameters, model architecture, and training data, it becomes easier to understand and interpret the results. Experiment tracking also allows for better collaboration and knowledge sharing among team members, as it provides a centralized repository of experiments and their associated metadata. Additionally, tracking experiments helps in debugging and troubleshooting, as it allows for the identification of specific settings or conditions that led to successful or unsuccessful outcomes. Overall, experiment tracking plays a crucial role in ensuring transparency, reproducibility, and continuous improvement in machine learning workflows.\n",
    "\n",
    "Now let's see how we can get all these benefits for free with PyTorch Tabular using Weights & Biases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "import random\n",
    "from pytorch_tabular.utils import load_covertype_dataset, print_metrics\n",
    "import pandas as pd\n",
    "import wandb\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmanujosephv\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data, cat_col_names, num_col_names, target_col = load_covertype_dataset()\n",
    "train, test = train_test_split(data, random_state=42)\n",
    "train, val = train_test_split(train, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import (\n",
    "    CategoryEmbeddingModelConfig,\n",
    "    FTTransformerConfig,\n",
    "    TabNetModelConfig,\n",
    "    GANDALFConfig,\n",
    ")\n",
    "from pytorch_tabular.config import (\n",
    "    DataConfig,\n",
    "    OptimizerConfig,\n",
    "    TrainerConfig,\n",
    "    ExperimentConfig,\n",
    ")\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Common Configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_config = DataConfig(\n",
    "    target=[\n",
    "        target_col\n",
    "    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    "    early_stopping=\"valid_loss\",  # Monitor valid_loss for early stopping\n",
    "    early_stopping_mode=\"min\",  # Set the mode as min because for val_loss, lower is better\n",
    "    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating\n",
    "    checkpoints=\"valid_loss\",  # Save best checkpoint monitoring val_loss\n",
    "    load_best=True,  # After training, load the best checkpoint\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\",  # No additional layer in head, just a mapping layer to output_dim\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "EXP_PROJECT_NAME = \"pytorch-tabular-covertype\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Category Embedding Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "Collapsed": "false",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b45f3f9e6f841d1bc0bbc445acbf7e1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112782611356427, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240106_114035-hvalkv16</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/hvalkv16' target=\"_blank\">CategoryEmbeddingModel_5</a></strong> to <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/hvalkv16' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/hvalkv16</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81aed7a908894178936247ad9ae7dd8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  823 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │    896 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │  3.6 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  823 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │    896 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │  3.6 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 827 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 827 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 3                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 827 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 827 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b5b00d2f4454cf18a0be3a7dc8b32a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f14867a4850>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"1024-512-512\",  # Number of nodes in each layer\n",
    "    activation=\"LeakyReLU\",  # Activation between each layers\n",
    "    learning_rate=1e-3,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "experiment_config = ExperimentConfig(\n",
    "    project_name=EXP_PROJECT_NAME,\n",
    "    run_name=\"CategoryEmbeddingModel\",\n",
    "    exp_watch=\"gradients\",\n",
    "    log_target=\"wandb\",\n",
    "    log_logits=True,\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    experiment_config=experiment_config,\n",
    "    verbose=False,\n",
    "    suppress_lightning_logger=True,\n",
    ")\n",
    "tabular_model.fit(train=train, validation=val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33e42b4493a8415aa2565816ba30b64b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9159672856330872     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.21389885246753693    </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9159672856330872    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.21389885246753693   \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = tabular_model.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c4c355484ec24619bc0fb1c5a75de5b9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███</td></tr><tr><td>test_accuracy</td><td>▁</td></tr><tr><td>test_loss</td><td>▁</td></tr><tr><td>train_accuracy</td><td>▁▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████</td></tr><tr><td>train_loss</td><td>█▇▆▅▅▅▄▄▃▄▃▄▄▃▂▃▃▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>valid_accuracy</td><td>▁▂▃▃▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█▇██▇███████▇██▇</td></tr><tr><td>valid_loss</td><td>█▇▆▆▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▂▁▁▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>52</td></tr><tr><td>test_accuracy</td><td>0.91597</td></tr><tr><td>test_loss</td><td>0.2139</td></tr><tr><td>train_accuracy</td><td>0.91938</td></tr><tr><td>train_loss</td><td>0.20502</td></tr><tr><td>trainer/global_step</td><td>16640</td></tr><tr><td>valid_accuracy</td><td>0.89782</td></tr><tr><td>valid_loss</td><td>0.24614</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">CategoryEmbeddingModel_5</strong> at: <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/hvalkv16' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/hvalkv16</a><br/>Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240106_114035-hvalkv16/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Although the experiment should finish automatically, it's safer to call it explicitly before running a new experiment\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## FT Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2eb20a8a19e4d64bc4b244d30fec4eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112755477531917, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240106_120910-k2qdphzr</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/k2qdphzr' target=\"_blank\">FTTransformer_1</a></strong> to <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/k2qdphzr' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/k2qdphzr</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "876a75dca2b2410cb720d741bc5a31e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                  </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ FTTransformerBackbone │ 86.5 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding2dLayer      │  2.2 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ LinearHead            │    231 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss      │      0 │\n",
       "└───┴──────────────────┴───────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                 \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ FTTransformerBackbone │ 86.5 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding2dLayer      │  2.2 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ LinearHead            │    231 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss      │      0 │\n",
       "└───┴──────────────────┴───────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 89.0 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 89.0 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 89.0 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 89.0 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3f882aa5ffe4f9e914aa91615e08a5d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f1486720910>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = FTTransformerConfig(\n",
    "    task=\"classification\",\n",
    "    num_attn_blocks=3,\n",
    "    num_heads=4,\n",
    "    learning_rate=1e-3,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "experiment_config = ExperimentConfig(\n",
    "    project_name=EXP_PROJECT_NAME,\n",
    "    run_name=\"FTTransformer\",\n",
    "    exp_watch=\"gradients\",\n",
    "    log_target=\"wandb\",\n",
    "    log_logits=True,\n",
    ")\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    experiment_config=experiment_config,\n",
    "    verbose=False,\n",
    "    suppress_lightning_logger=True,\n",
    ")\n",
    "tabular_model.fit(train=train, validation=val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09f93608e28b4cba93b617d63701c097",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9120706915855408     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.21334321796894073    </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9120706915855408    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.21334321796894073   \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = tabular_model.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a16663f04d4d4bc19b7021f259e2bb4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.010 MB of 0.010 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>test_accuracy</td><td>▁</td></tr><tr><td>test_loss</td><td>▁</td></tr><tr><td>train_accuracy</td><td>▁▃▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████</td></tr><tr><td>train_loss</td><td>█▇▆▅▆▄▄▄▄▄▃▄▃▃▂▃▃▃▂▂▂▁▃▃▃▄▃▂▂▄▂▂▁▂▁▃▂▃▂▁</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>valid_accuracy</td><td>▁▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████</td></tr><tr><td>valid_loss</td><td>█▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>47</td></tr><tr><td>test_accuracy</td><td>0.91207</td></tr><tr><td>test_loss</td><td>0.21334</td></tr><tr><td>train_accuracy</td><td>0.88804</td></tr><tr><td>train_loss</td><td>0.23555</td></tr><tr><td>trainer/global_step</td><td>15040</td></tr><tr><td>valid_accuracy</td><td>0.91161</td></tr><tr><td>valid_loss</td><td>0.21692</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">FTTransformer_1</strong> at: <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/k2qdphzr' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/k2qdphzr</a><br/>Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240106_120910-k2qdphzr/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## GANDALF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf4bdaca4db841a4a4d8444f2ed23e4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.01111247184453532, max=1.0)…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240106_123420-9kg2s1qg</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/9kg2s1qg' target=\"_blank\">GANDALF_1</a></strong> to <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/9kg2s1qg' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/9kg2s1qg</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51b72498eaad476ca68ab2c9b7c9ff13",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ GANDALFBackbone  │ 70.7 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer │    896 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ Sequential       │    252 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ GANDALFBackbone  │ 70.7 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │    896 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ Sequential       │    252 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 71.9 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 71.9 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 71.9 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 71.9 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d7b2688f2e24bc88abe3fb144001006",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f14866c8a50>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = GANDALFConfig(\n",
    "    task=\"classification\",\n",
    "    gflu_stages=10,\n",
    "    learning_rate=1e-3,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "experiment_config = ExperimentConfig(\n",
    "    project_name=EXP_PROJECT_NAME,\n",
    "    run_name=\"GANDALF\",\n",
    "    exp_watch=\"gradients\",\n",
    "    log_target=\"wandb\",\n",
    "    log_logits=True,\n",
    ")\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    experiment_config=experiment_config,\n",
    "    verbose=False,\n",
    "    suppress_lightning_logger=True,\n",
    ")\n",
    "tabular_model.fit(train=train, validation=val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "284d348e5ca449588a89c7925b85db00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8669493794441223     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.32519233226776123    </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8669493794441223    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.32519233226776123   \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = tabular_model.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dbd8004e3a5e4ee6a496c0ec3c0c5521",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█</td></tr><tr><td>test_accuracy</td><td>▁</td></tr><tr><td>test_loss</td><td>▁</td></tr><tr><td>train_accuracy</td><td>▁▆▇███████</td></tr><tr><td>train_loss</td><td>█▇▄▃▃▂▁▃▂▂▂▂▃▁▂▃▁▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▂▃▂▁▂▁▂▄</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████</td></tr><tr><td>valid_accuracy</td><td>▁▆▆▆▇█▇▇██</td></tr><tr><td>valid_loss</td><td>█▃▂▂▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>10</td></tr><tr><td>test_accuracy</td><td>0.86695</td></tr><tr><td>test_loss</td><td>0.32519</td></tr><tr><td>train_accuracy</td><td>0.86358</td></tr><tr><td>train_loss</td><td>0.42013</td></tr><tr><td>trainer/global_step</td><td>3200</td></tr><tr><td>valid_accuracy</td><td>0.86707</td></tr><tr><td>valid_loss</td><td>0.32778</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">GANDALF_1</strong> at: <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/9kg2s1qg' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/9kg2s1qg</a><br/>Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240106_123420-9kg2s1qg/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## TabNet Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "133a141208fb4b23bb0eb7ed8167ee00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112859611037291, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240106_124017-iw6q00dk</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/iw6q00dk' target=\"_blank\">TabNet_1</a></strong> to <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/iw6q00dk' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/iw6q00dk</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "494f071e8dee4ae2ad3b4bf0f9515854",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _embedding_layer │ Identity         │      0 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _backbone        │ TabNetBackbone   │ 29.3 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ Identity         │      0 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Identity         │      0 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ TabNetBackbone   │ 29.3 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ Identity         │      0 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 29.3 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 29.3 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 29.3 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 29.3 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73c51483922c4e0bb54abe0fbf32fc57",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f1487491310>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = TabNetModelConfig(\n",
    "    task=\"classification\",\n",
    "    learning_rate=1e-5,\n",
    "    n_d=16,\n",
    "    n_a=16,\n",
    "    n_steps=4,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "experiment_config = ExperimentConfig(\n",
    "    project_name=EXP_PROJECT_NAME,\n",
    "    run_name=\"TabNet\",\n",
    "    exp_watch=\"gradients\",\n",
    "    log_target=\"wandb\",\n",
    "    log_logits=True,\n",
    ")\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    experiment_config=experiment_config,\n",
    "    verbose=False,\n",
    "    suppress_lightning_logger=True,\n",
    ")\n",
    "tabular_model.fit(train=train, validation=val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18e3eb46e16f48318079f2e1377a1caa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7186701893806458     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6771128177642822     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7186701893806458    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6771128177642822    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = tabular_model.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "63b7bf42b15c4d0a8273caf127c9753d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇█</td></tr><tr><td>test_accuracy</td><td>▁</td></tr><tr><td>test_loss</td><td>▁</td></tr><tr><td>train_accuracy</td><td>▁▄▅▇█▆</td></tr><tr><td>train_loss</td><td>▇▄▃▂▃▃▃▃▁▄█▄▅▄▄▃▂▂▂▄▃▂▂▂▃▂▂▂▂▂▂▂▂▁▄▄▅▃</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████</td></tr><tr><td>valid_accuracy</td><td>█▇▁█▅▄</td></tr><tr><td>valid_loss</td><td>▁▃█▁▂█</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>6</td></tr><tr><td>test_accuracy</td><td>0.71867</td></tr><tr><td>test_loss</td><td>0.67711</td></tr><tr><td>train_accuracy</td><td>0.70094</td></tr><tr><td>train_loss</td><td>0.7142</td></tr><tr><td>trainer/global_step</td><td>1920</td></tr><tr><td>valid_accuracy</td><td>0.65582</td></tr><tr><td>valid_loss</td><td>0.95996</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">TabNet_1</strong> at: <a href='https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/iw6q00dk' target=\"_blank\">https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/iw6q00dk</a><br/>Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240106_124017-iw6q00dk/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accessing the Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can access the runs @ https://wandb.ai/manujosephv/pytorch-tabular-covertype/\n",
    "![](imgs/wandb_preview_1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also inspect gradient flows in each component of the model for debugging purposes.\n",
    "![](imgs/wandb_preview_2.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
